{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a naive implementation of network reconstruction via the Gumbel-softmax trick."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "%matplotlib inline\n",
    "import torch\n",
    "from torch.nn.functional import gumbel_softmax\n",
    "import scipy\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Coupled Map Lattice (CML) dynamics on a random k-regular graph. Here we do not train a GNN; we assume that we know the exact form of the dynamics. We use the current state to predict the next state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(device=None, seed=None):\n",
    "    \"\"\"Reconstruct a random 4-regular graph from coupled-map-lattice trajectories.\n",
    "\n",
    "    Simulates CML trajectories on a known graph, then trains an (n, n, 2) tensor of\n",
    "    edge logits so that a Gumbel-softmax sample of it, plugged into the known\n",
    "    one-step dynamics, predicts x_{t+1} from x_t.\n",
    "\n",
    "    Args:\n",
    "        device: torch device used for training. When None, 'cuda:1' is used if at\n",
    "            least two GPUs are visible, else 'cuda:0', else 'cpu'.\n",
    "        seed: optional int seeding numpy, torch and the graph generator.\n",
    "\n",
    "    Returns:\n",
    "        Fraction of off-diagonal adjacency entries on which the hard sample drawn\n",
    "        after the last completed epoch disagrees with the true graph.\n",
    "    \"\"\"\n",
    "    def logistic_map(x, lambd):\n",
    "        \"\"\"Element-wise logistic map lambd * x * (1 - x).\"\"\"\n",
    "        return lambd * x * (1 - x)\n",
    "\n",
    "    def one_step_dynamics(current_state, A, s=0.2, lambd=3.8, eps=0, device='cpu'):\n",
    "        \"\"\"Given x_t (n by 1, optionally batched) and coupling A, produce x_t+1.\n",
    "\n",
    "        x_{t+1} = (1 - s) * f(x_t) + s * D^{-1} A f(x_t) + eps * noise, where f is\n",
    "        the logistic map and D holds the row sums of A (zero row sums are replaced\n",
    "        by 1). Accepts numpy arrays as well as torch tensors.\n",
    "        \"\"\"\n",
    "        mapped = logistic_map(current_state, lambd=lambd)\n",
    "        # Keep the row sums as a column so that row i of A is divided by its own\n",
    "        # degree; a flat (n,) vector would broadcast over the columns instead.\n",
    "        if isinstance(current_state, np.ndarray):\n",
    "            diags = A.sum(1, keepdims=True)\n",
    "            diags = np.where(diags == 0, 1, diags)  # avoid nan\n",
    "            noise = np.random.randn(n, 1) * eps\n",
    "        else:\n",
    "            diags = A.sum(1, keepdim=True)\n",
    "            diags = torch.where(diags != 0, diags, torch.ones_like(diags))  # avoid nan\n",
    "            noise = (torch.randn(n, 1) * eps).to(device)\n",
    "        return (1 - s) * mapped + (s / diags * A) @ mapped + noise\n",
    "\n",
    "    def gumbel_sample(x, temp=1, hard=False, eps=1e-5):\n",
    "        \"\"\"Sample an n-by-n edge matrix from logits x of shape (n, n, 2).\n",
    "\n",
    "        Channel 0 of the Gumbel-softmax output is returned. `eps` is unused and\n",
    "        only kept so existing calls stay valid.\n",
    "        \"\"\"\n",
    "        logp = x.view(-1, 2)\n",
    "        out = gumbel_softmax(logp, temp, hard)\n",
    "        return out[:, 0].view(n, n)\n",
    "\n",
    "    if device is None:\n",
    "        # Prefer the second GPU as before, but degrade gracefully when it is missing.\n",
    "        if torch.cuda.is_available():\n",
    "            device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'\n",
    "        else:\n",
    "            device = 'cpu'\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "    n = 10\n",
    "    instances = 100\n",
    "    steps = 50\n",
    "    bs = 128\n",
    "    lr = 1e-3\n",
    "    epochs = 800\n",
    "    tau = 5\n",
    "\n",
    "    # data generation\n",
    "    G = nx.random_regular_graph(4, n, seed=seed)\n",
    "    A = np.array(nx.adjacency_matrix(G).todense())  # adjacency matrix A\n",
    "    # For every random initialization, we run `steps` steps\n",
    "    simulates = np.empty((instances, steps, n, 1))\n",
    "    for i in range(instances):\n",
    "        s = np.random.rand(n, 1)\n",
    "        for j in range(steps):\n",
    "            s = one_step_dynamics(s, A=A)\n",
    "            simulates[i, j] = s\n",
    "    # We use current step to predict the next step, [# data, N, 2]. Pairs are only\n",
    "    # formed inside a trajectory, never across two different initializations.\n",
    "    trajectories = simulates[..., 0]\n",
    "    pairs = np.stack((trajectories[:, :-1], trajectories[:, 1:]), axis=-1)\n",
    "    dataset = torch.from_numpy(pairs.reshape(-1, n, 2)).to(torch.float32).to(device)\n",
    "    data_loader = DataLoader(dataset, batch_size=bs, shuffle=True)\n",
    "\n",
    "    x = -1 * torch.rand(n, n, 2, device=device)  # unnormalized logits\n",
    "    x.requires_grad = True\n",
    "    optimizer = torch.optim.Adam([x], lr=lr)\n",
    "    loss_arr = []\n",
    "    error_arr = []\n",
    "    mask = 1 - np.eye(n)  # the diagonal (self-loops) is excluded from the error\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        # Training\n",
    "        for batch in data_loader:  # batch size [bs, n, 2]\n",
    "            # Without this the gradients of all previous batches keep accumulating.\n",
    "            optimizer.zero_grad()\n",
    "            A_p = gumbel_sample(x, hard=False, temp=tau)\n",
    "            s_t = batch[:, :, 1].unsqueeze(2)\n",
    "            s_p = one_step_dynamics(batch[:, :, 0].unsqueeze(2), A_p, device=device)\n",
    "            loss = torch.mean((s_t - s_p) ** 2)\n",
    "            loss_value = loss.item()\n",
    "            loss_arr.append(loss_value)\n",
    "            if not np.isfinite(loss_value):\n",
    "                # One non-finite sample would turn every logit into nan for good,\n",
    "                # so drop this batch instead of stepping on it.\n",
    "                continue\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        # Test\n",
    "        with torch.no_grad():\n",
    "            A_p = gumbel_sample(x, hard=True, temp=1).cpu().numpy()\n",
    "        error = np.abs(A_p * mask - A).sum() / (n ** 2 - n)\n",
    "        error_arr.append(error)\n",
    "        if epoch % 100 == 0:\n",
    "            print('Epoch: %i\\t Error rate: %.5f' % (epoch, error))\n",
    "        if np.abs(error) < 0.0001:\n",
    "            break\n",
    "    return error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0\t Error rate: 0.43333\n",
      "Epoch: 0\t Error rate: 0.53333\n",
      "Epoch: 0\t Error rate: 0.51111\n",
      "Epoch: 0\t Error rate: 0.56667\n",
      "Epoch: 0\t Error rate: 0.47778\n",
      "Epoch: 0\t Error rate: 0.47778\n",
      "Epoch: 0\t Error rate: 0.44444\n",
      "Epoch: 100\t Error rate: nan\n",
      "Epoch: 200\t Error rate: nan\n",
      "Epoch: 300\t Error rate: nan\n",
      "Epoch: 400\t Error rate: nan\n",
      "Epoch: 500\t Error rate: nan\n",
      "Epoch: 600\t Error rate: nan\n",
      "Epoch: 700\t Error rate: nan\n",
      "Epoch: 0\t Error rate: 0.47778\n",
      "Epoch: 0\t Error rate: 0.58889\n",
      "Epoch: 0\t Error rate: 0.38889\n"
     ]
    }
   ],
   "source": [
    "# Ten independent reconstruction runs; each entry is the value returned by main().\n",
    "results_arr = [main() for _ in range(10)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.06436781609195402,\n",
       " nan,\n",
       " nan,\n",
       " nan,\n",
       " nan,\n",
       " nan,\n",
       " 0.0735632183908046,\n",
       " nan,\n",
       " nan,\n",
       " nan]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0, nan, 0.0, nan, 0.0, nan, nan, nan, 0.0, nan]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0,\n",
       " 0.0,\n",
       " 0.0,\n",
       " nan,\n",
       " 0.1111111111111111,\n",
       " 0.0,\n",
       " 0.022222222222222223,\n",
       " 0.0,\n",
       " 0.044444444444444446,\n",
       " 0.1111111111111111]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, nan, 0.0, 0.0, 0.0]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'loss_arr' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-37-6bb255748606>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_arr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;31m# plt.plot(error_arr)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'loss_arr' is not defined"
     ]
    }
   ],
   "source": [
    "# NOTE(review): loss_arr is a local variable of main() and is never returned, so\n",
    "# this cell raises NameError (see the recorded traceback) on a fresh kernel.\n",
    "plt.plot(loss_arr)\n",
    "# plt.plot(error_arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'error_arr' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-38-740ac2bc5168>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror_arr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'error_arr' is not defined"
     ]
    }
   ],
   "source": [
    "# NOTE(review): error_arr only exists inside main(); this cell raises NameError as recorded.\n",
    "plt.plot(error_arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'A_p' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-39-3889d12b1b26>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mA_p\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'A_p' is not defined"
     ]
    }
   ],
   "source": [
    "# NOTE(review): A_p (the sampled adjacency) only exists inside main(); this cell raises NameError as recorded.\n",
    "plt.imshow(A_p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'A' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-40-95bc0cbdc54b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mA\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'A' is not defined"
     ]
    }
   ],
   "source": [
    "# NOTE(review): A (the true adjacency) only exists inside main(); this cell raises NameError as recorded.\n",
    "plt.imshow(A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'A_p' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-41-c61030abc445>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mA_p\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'A_p' is not defined"
     ]
    }
   ],
   "source": [
    "# NOTE(review): A_p and A only exist inside main(); this cell raises NameError as recorded.\n",
    "plt.imshow(A_p - A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
